import TLSM_method as Tm
from spectral_aggregation import mean_adj_link_prediction
from hosvd_tucker import hosvd_tucker_link_prediction
from multiprocessing import Pool
import time
import numpy as np
import tensorly as tl
tl.set_backend('numpy')

# ---------------------------------------------------------------------------
# Data loading and preprocessing
# ---------------------------------------------------------------------------
# The edge tensor is indexed as Tensor[node, node, layer]: the axis sums below
# treat axes 0/1 as nodes and axis 2 as layers.
# NOTE(review): presumably a multilayer adjacency tensor; confirm against how
# wat_edge.npy was produced.
Tensor = np.load('wat_edge.npy')
Node_set = np.load('wat_node.npy')
# Replace the label stored at position 58 with a plain-ASCII spelling.
# NOTE(review): presumably repairs a mis-encoded name; the hard-coded index is
# only valid for this particular wat_node.npy.
Node_set[58] = "Cote d'Ivoire"
# Binarise: every positive entry becomes 1.
Tensor[Tensor > 0] = 1

# Keep only the M layers that contain the most edges.
M = 32

layer_deg = Tensor.sum(axis=(0, 1))              # edge count per layer
index = list(layer_deg.argsort()[::-1][:M])      # layer ids, largest first
NT = Tensor[:, :, index]

# Keep nodes whose row sum over the kept layers, averaged per layer, exceeds 9.
# Rows and columns are filtered with the same ids so the two node axes match.
deg = NT.sum(axis=(1, 2))
ind = np.nonzero(deg / M > 9)[0]
# Node_set is deliberately left unfiltered (it is not used in the code visible
# here), so its positions do NOT line up with NT's node axes.
NT = NT[ind][:, ind, :]

# Alternative preprocessing, kept for reference only (not executed): filter
# nodes on the full tensor using 364 layers (presumably the total layer count)
# and subset Node_set to match.
#   deg = Tensor.sum(axis=(1, 2))
#   ind = np.where(deg/364 > 9)[0]
#   Node_set = Node_set[ind]
#   NT = Tensor[ind]
#   NT = NT[:, ind, :]

# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
K = 6                          # passed as K to Tm.TLSM and both baselines
fraction_of_training = 0.8     # passed as the training fraction to every method
number_of_iter_learn = 1800    # stored on the TLSM model as Num_ite
Repetition = 50                # number of replications reported in __main__


def multi_processing(parameter):
    """Fit one TLSM model and return its link-prediction error.

    Written as a ``Pool.map`` worker, hence the single positional argument.

    Parameters
    ----------
    parameter : sequence
        ``[A, K, fraction_of_training, number_of_iter_learn]``: the network
        tensor, the ``K`` argument of ``Tm.TLSM``, the training fraction, and
        the iteration count assigned to the model's ``Num_ite``. Any extra
        trailing items are ignored.

    Returns
    -------
    object
        Whatever ``Get_link_prediciton_error(alpha, beta)`` returns after
        ``training()`` has run (presumably a scalar error; confirm in
        TLSM_method).
    """
    # Slice before unpacking so extra trailing items stay ignored, exactly as
    # with positional indexing.
    A, K, fraction_of_training, number_of_iter_learn = parameter[:4]
    model = Tm.TLSM(A, K, fraction_of_training)
    model.Num_ite = number_of_iter_learn
    # training() is called for its effect on the model; its return value was
    # never used.
    # NOTE(review): presumably it populates model.alpha / model.beta read below.
    model.training()
    return model.Get_link_prediciton_error(model.alpha, model.beta)


def _reseed_worker():
    """Give each pool worker its own NumPy global RNG state.

    With the ``fork`` start method every worker inherits a copy of the
    parent's ``np.random`` state. Any method that draws from the global RNG
    (e.g. for a random train/test split) would therefore produce identical
    draws in different workers, and replications would not be independent.
    Calling ``np.random.seed()`` with no argument reseeds from OS entropy.
    This has no effect on code that does not use ``np.random``.
    """
    np.random.seed()


def _run_replications(pool, func, params, repetition):
    """Evaluate ``func(params)`` ``repetition`` times on ``pool``.

    Returns the per-replication results as a NumPy array, in task order.
    """
    # chunksize=1: each replication is a long fit, so hand tasks out one at a
    # time to keep all workers busy.
    return np.array(pool.map(func, [params] * repetition, chunksize=1))


def _report(method_name, errors):
    """Print the mean of ``errors`` and its standard error.

    The standard error is ``std / sqrt(n)`` with NumPy's default ``ddof=0``,
    as before.
    NOTE(review): the sample SD (``ddof=1``) is the more usual choice.
    """
    n = len(errors)
    print("Link prediction error by", method_name, "over", n, "independent replications is",
          errors.mean(), "with standard error", errors.std() / np.sqrt(n))


if __name__ == '__main__':
    n_workers = 10
    # On any exception the with-statement terminates the pool; on success we
    # shut it down gracefully with close()/join() below.
    with Pool(n_workers, initializer=_reseed_worker) as pool:
        # TLSM worker arguments: [A, K, fraction_of_training, number_of_iter_learn];
        # both baselines take [A, K, fraction_of_training].
        tlsm_params = [NT, K, fraction_of_training, number_of_iter_learn]
        baseline_params = [NT, K, fraction_of_training]

        # The replication count comes from Repetition, so the reported n and
        # standard error always match the number of runs actually performed.
        link_pred_self = _run_replications(pool, multi_processing, tlsm_params, Repetition)
        _report("TLSM", link_pred_self)

        link_pred_mean_adj = _run_replications(pool, mean_adj_link_prediction,
                                               baseline_params, Repetition)
        _report("mean_adj", link_pred_mean_adj)

        link_pred_tucker = _run_replications(pool, hosvd_tucker_link_prediction,
                                             baseline_params, Repetition)
        _report("HOSVD-Tucker", link_pred_tucker)

        pool.close()
        pool.join()

